package face2wap;

import deeplearn.Gan;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Map;
import java.util.Random;

/**
 * Created by susaneraly on 6/9/16.
 */
/**
 * GAN-style image pipeline example built on DL4J's {@code ComputationGraph}.
 *
 * <p>A generator (layers g1–g3) maps an input vector to a flattened image; a
 * {@code StackVertex} stacks real images ("input2") on top of generated ones
 * ("g3"); a discriminator (d1–d3 + sigmoid "out") scores each stacked row as
 * real/fake. Training alternates between {@link #trainD} (generator frozen)
 * and {@link #trainG} (discriminator frozen), freezing via per-layer learning
 * rates of 0.
 */
@Slf4j
public class ImgGanPipelineExample {

    static String model = "F:/face/ganimg.zip"; // path used to save/restore the trained graph

    static int height = 60; // input image height

    static int width = 60; // input image width

    static int channels = 3; // input image channels (RGB)

    static int outputNum = 1; // single sigmoid unit: real-vs-fake score

    static int batchSize = 64;

    static int nEpochs = 1000000;

    static int seed = 1234; // fixed seed so file shuffling is reproducible

    static Random randNumGen = new Random(seed);
    static double lr = 0.01;

    public static void main(String[] args) throws Exception {

        String inputDataDir = "F:/face/yzm";
        File trainData = new File(inputDataDir + "/train");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // parent path as the image label
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
        // Optional: scale pixels from 0-255 to 0-1 (min-max scaling).
       // DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
       // scaler.fit(trainIter);
       // trainIter.setPreProcessor(scaler);

        File testData = new File(inputDataDir + "/test");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
      //  testIter.setPreProcessor(scaler); // same normalization for better results

        // Network layers and hyper-parameters.

        final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
                .weightInit(WeightInit.XAVIER);

        final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard)
                .addInputs("input1", "input2")
                .setInputTypes(InputType.feedForward(height * width * channels), InputType.feedForward(height * width * channels))
                // Generator: input1 -> g1 -> g2 -> g3 (flattened image).
                .addLayer("g1",
                        new DenseLayer.Builder().nOut(128).activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER).build(),
                        "input1")
                .addLayer("g2",
                        new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER).build(),
                        "g1")
                .addLayer("g3",
                        new DenseLayer.Builder().nIn(512).nOut(height * width * channels).activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER).build(),
                        "g2")
                // Real rows ("input2") are stacked ABOVE generated rows ("g3").
                .addVertex("stack", new StackVertex(), "input2", "g3")
                // Discriminator: stack -> d1 -> d2 -> d3 -> out (sigmoid score).
                .addLayer("d1",
                        new DenseLayer.Builder().nIn(height * width * channels).nOut(256).activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER).build(),
                        "stack")
                .addLayer("d2",
                        new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER).build(),
                        "d1")
                .addLayer("d3",
                        new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER).build(),
                        "d2")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
                        .activation(Activation.SIGMOID).build(), "d3")
                .setOutputs("out");
        ComputationGraph net;
        if (new File(model).exists()) {
            // Resume from a previously saved checkpoint.
            net = ComputationGraph.load(new File(model), true);
        } else {
            net = new ComputationGraph(graphBuilder.build());
        }
        // init() is a no-op on an already-initialized (loaded) graph.
        net.init();
        net.setListeners(new ScoreIterationListener(100));
        System.out.println(net.summary());

        for (int i = 0; i < nEpochs; i++) {
            INDArray trueExp = trainIter.next().getFeatures();
            // Flatten to (rows, h*w*c). Use the actual row count instead of a
            // hard-coded 64 so a short final batch cannot break the reshape.
            long rows = trueExp.size(0);
            trueExp = trueExp.reshape(rows, (long) height * width * channels);

            // Labels for the stacked minibatch: the first `rows` rows (real) get 1,
            // the generated rows get 0. This replaces the original Nd4j.ones(128, 2),
            // whose shape did not match the nOut(1) sigmoid output and whose all-ones
            // values gave the discriminator no real/fake contrast to learn from.
            INDArray labelD = Nd4j.vstack(Nd4j.ones(rows, 1), Nd4j.zeros(rows, 1));
            // Generator target: have generated rows scored as real (1).
            INDArray labelG = Nd4j.ones(rows * 2, 1);

            // NOTE(review): a constant generator input (ones here, zeros below) is
            // unusual for a GAN — typically z is random noise; kept as authored.
            INDArray z = Nd4j.ones(trueExp.shape());
            MultiDataSet dataSetD = new MultiDataSet(new INDArray[] {z, trueExp},
                    new INDArray[] {labelD});

            // Train the discriminator several steps per generator step.
            for (int m = 0; m < 10; m++) {
                trainD(net, dataSetD);
            }

            z = Nd4j.zeros(trueExp.shape());
            MultiDataSet dataSetG = new MultiDataSet(new INDArray[] {z, trueExp},
                    new INDArray[] {labelG});
            trainG(net, dataSetG);

            if (i % 10 == 1) {
                System.out.println("Iteration " + i / 10 + " Visualizing...");
                Map<String, INDArray> map = net.feedForward(
                        new INDArray[] {Nd4j.rand(trueExp.shape()), trueExp}, false);
                INDArray indArray = map.get("g3"); // generator output, one flattened image per row
                visualize(indArray);
            }

            if (i % 1000 == 0) {
                net.save(new File(model), true);
            }

            trainIter.reset();

            testIter.reset();

        }

        // Persist the final model.
        ModelSerializer.writeModel(net, new File(inputDataDir + "/mouth-model.zip"), true);
    }

    /**
     * One discriminator training step D(x): freeze the generator layers by
     * zeroing their learning rates, keep the discriminator layers at {@code lr}.
     */
    public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
        net.setLearningRate("g1", 0);
        net.setLearningRate("g2", 0);
        net.setLearningRate("g3", 0);
        net.setLearningRate("d1", lr);
        net.setLearningRate("d2", lr);
        net.setLearningRate("d3", lr);
        net.setLearningRate("out", lr);
        net.fit(dataSet);
    }

    /**
     * One generator training step G(z): freeze the discriminator layers by
     * zeroing their learning rates, keep the generator layers at {@code lr}.
     */
    public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
        net.setLearningRate("g1", lr);
        net.setLearningRate("g2", lr);
        net.setLearningRate("g3", lr);
        net.setLearningRate("d1", 0);
        net.setLearningRate("d2", 0);
        net.setLearningRate("d3", 0);
        net.setLearningRate("out", 0);
        net.fit(dataSet);
    }

    private static JFrame frame;
    private static JPanel panel;

    /** Renders the first sample of the generator output in a lazily created Swing frame. */
    private static void visualize(INDArray samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            panel = new JPanel();
            panel.setLayout(new GridLayout(0, 5));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();

        panel.add(getImage(samples));

        frame.revalidate();
        frame.pack();
    }

    /**
     * Converts the first {@code width * height} values of {@code tensor} (assumed
     * in [0,1]) into a grayscale image wrapped in a {@link JLabel}.
     */
    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        // Fixes vs. original: the row index is i / width (the original's i / height
        // was only correct because width == height), the pixel count follows the
        // dimensions instead of a hard-coded 3600, and values are clamped to
        // [0, 255] so RELU outputs > 1 cannot corrupt the byte raster.
        for (int i = 0; i < width * height; i++) {
            int pixel = (int) (255 * tensor.getDouble(i));
            pixel = Math.max(0, Math.min(255, pixel));
            bi.getRaster().setSample(i % width, i / width, 0, pixel);
        }
        ImageIcon orig = new ImageIcon(bi);
        Image imageScaled = orig.getImage().getScaledInstance(width, height, Image.SCALE_REPLICATE);

        ImageIcon scaled = new ImageIcon(imageScaled);

        return new JLabel(scaled);
    }
}